/*
 * Copyright 2021 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package nphttp2

import (
	"bytes"
	"container/list"
	"context"
	"errors"
	"fmt"
	"net"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cloudwego/netpoll"

	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/gofunc"
	"github.com/cloudwego/kitex/pkg/kerrors"
	"github.com/cloudwego/kitex/pkg/klog"
	"github.com/cloudwego/kitex/pkg/remote"
	"github.com/cloudwego/kitex/pkg/remote/codec"
	"github.com/cloudwego/kitex/pkg/remote/codec/grpc"
	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
	grpcTransport "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
	"github.com/cloudwego/kitex/pkg/rpcinfo"
	"github.com/cloudwego/kitex/pkg/serviceinfo"
	"github.com/cloudwego/kitex/pkg/stats"
	"github.com/cloudwego/kitex/pkg/streaming"
	"github.com/cloudwego/kitex/transport"
)

type svrTransHandlerFactory struct{}

// NewSvrTransHandlerFactory ...
func NewSvrTransHandlerFactory() remote.ServerTransHandlerFactory {
	return &svrTransHandlerFactory{}
}

func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) {
	return newSvrTransHandler(opt)
}

func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) {
	return &svrTransHandler{
		opt:         opt,
		svcSearcher: opt.SvcSearcher,
		codec:       grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)),
		li:          list.New(),
	}, nil
}

var _ remote.ServerTransHandler = &svrTransHandler{}

type svrTransHandler struct {
	opt         *remote.ServerOption
	svcSearcher remote.ServiceSearcher
	inkHdlFunc  endpoint.Endpoint
	codec       remote.Codec

	mu sync.Mutex
	// maintain all active server transports
	li *list.List
}

var prefaceReadAtMost = func() int {
	// min(len(ClientPreface), len(flagBuf))
	// len(flagBuf) = 2 * codec.Size32
	if 2*codec.Size32 < grpcTransport.ClientPrefaceLen {
		return 2 * codec.Size32
	}
	return grpcTransport.ClientPrefaceLen
}()

func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) error {
	// Check the validity of client preface.
	// FIXME: should not rely on netpoll.Reader
	if withReader, ok := conn.(interface{ Reader() netpoll.Reader }); ok {
		if npReader := withReader.Reader(); npReader != nil {
			// read at most avoid block
			preface, err := npReader.Peek(prefaceReadAtMost)
			if err != nil {
				return err
			}
			if len(preface) >= prefaceReadAtMost && bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) {
				return nil
			}
		}
	}
	return errors.New("error protocol not match")
}

func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
	buf := newBuffer(conn.(*serverConn))
	defer buf.Release(err)

	if err = t.codec.Encode(ctx, msg, buf); err != nil {
		return ctx, err
	}
	return ctx, buf.Flush()
}

func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
	buf := newBuffer(conn.(*serverConn))
	defer buf.Release(err)

	err = t.codec.Decode(ctx, msg, buf)
	return ctx, err
}

// 只 return write err
func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error {
	svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans)
	tr := svrTrans.tr
	tr.HandleStreams(func(s *grpcTransport.Stream) {
		atomic.AddInt32(&svrTrans.handlerNum, 1)
		gofunc.GoFunc(ctx, func() {
			t.handleFunc(s, svrTrans, conn)
			atomic.AddInt32(&svrTrans.handlerNum, -1)
		})
	}, func(ctx context.Context, method string) context.Context {
		return ctx
	})
	return nil
}

func (t *svrTransHandler) handleFunc(s *grpcTransport.Stream, svrTrans *SvrTrans, conn net.Conn) {
	tr := svrTrans.tr
	ri := svrTrans.pool.Get().(rpcinfo.RPCInfo)
	rCtx := rpcinfo.NewCtxWithRPCInfo(s.Context(), ri)
	defer func() {
		// reset rpcinfo for performance (PR #584)
		if rpcinfo.PoolEnabled() {
			ri = t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr())
			svrTrans.pool.Put(ri)
		}
	}()

	ink := ri.Invocation().(rpcinfo.InvocationSetter)
	sm := s.Method()
	if sm != "" && sm[0] == '/' {
		sm = sm[1:]
	}
	pos := strings.LastIndex(sm, "/")
	if pos == -1 {
		errDesc := fmt.Sprintf("malformed method name, method=%q", s.Method())
		tr.WriteStatus(s, status.New(codes.Internal, errDesc))
		return
	}
	methodName := sm[pos+1:]
	ink.SetMethodName(methodName)

	if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil {
		if err := mutableTo.SetMethod(methodName); err != nil {
			errDesc := fmt.Sprintf("setMethod failed in streaming, method=%s, error=%s", methodName, err.Error())
			_ = tr.WriteStatus(s, status.New(codes.Internal, errDesc))
			return
		}
	}

	var serviceName string
	idx := strings.LastIndex(sm[:pos], ".")
	if idx == -1 {
		ink.SetPackageName("")
		serviceName = sm[0:pos]
	} else {
		ink.SetPackageName(sm[:idx])
		serviceName = sm[idx+1 : pos]
	}
	ink.SetServiceName(serviceName)

	// set grpc transport flag before execute metahandler
	rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.GRPC)
	var err error
	for _, shdlr := range t.opt.StreamingMetaHandlers {
		rCtx, err = shdlr.OnReadStream(rCtx)
		if err != nil {
			tr.WriteStatus(s, convertStatus(err))
			return
		}
	}
	rCtx = t.startTracer(rCtx, ri)
	defer func() {
		panicErr := recover()
		if panicErr != nil {
			if conn != nil {
				klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack()))
			} else {
				klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack()))
			}
		}
		t.finishTracer(rCtx, ri, err, panicErr)
	}()

	// set recv grpc compressor at server to decode the pack from client
	remote.SetRecvCompressor(ri, s.RecvCompress())
	// set send grpc compressor at server to encode reply pack
	remote.SetSendCompressor(ri, s.SendCompress())

	svcInfo := t.svcSearcher.SearchService(serviceName, methodName, true)
	var methodInfo serviceinfo.MethodInfo
	if svcInfo != nil {
		methodInfo = svcInfo.MethodInfo(methodName)
	}

	rawStream := &stream{
		ctx:     rCtx,
		svcInfo: svcInfo,
		conn:    newServerConn(tr, s),
		handler: t,
	}
	// inject rawStream so that GetServerConn only relies on it
	rCtx = context.WithValue(rCtx, serverConnKey{}, rawStream)
	st := newStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint)
	// bind stream into ctx, in order to let user set header and trailer by provided api in meta_api.go
	rCtx = streaming.NewCtxWithStream(rCtx, st)
	// GetServerConn could retrieve rawStream by Stream.Context().Value(serverConnKey{})
	rawStream.ctx = rCtx

	if methodInfo == nil {
		unknownServiceHandlerFunc := t.opt.GRPCUnknownServiceHandler
		if unknownServiceHandlerFunc != nil {
			rpcinfo.Record(rCtx, ri, stats.ServerHandleStart, nil)
			err = unknownServiceHandlerFunc(rCtx, methodName, st)
			if err != nil {
				err = kerrors.ErrBiz.WithCause(err)
			}
		} else {
			if svcInfo == nil {
				err = remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", serviceName))
			} else {
				err = remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", methodName))
			}
		}
	} else {
		if streaming.UnaryCompatibleMiddleware(methodInfo.StreamingMode(), t.opt.CompatibleMiddlewareForUnary) {
			// making streaming unary APIs capable of using the same server middleware as non-streaming APIs
			// note: rawStream skips recv/send middleware for unary API requests to avoid confusion
			err = invokeStreamUnaryHandler(rCtx, rawStream, methodInfo, t.inkHdlFunc, ri)
		} else {
			err = t.inkHdlFunc(rCtx, &streaming.Args{Stream: st}, nil)
		}
	}

	if err != nil {
		tr.WriteStatus(s, convertStatus(err))
		t.OnError(rCtx, err, conn)
		return
	}
	if bizStatusErr := ri.Invocation().BizStatusErr(); bizStatusErr != nil {
		var st *status.Status
		if sterr, ok := bizStatusErr.(status.Iface); ok {
			st = sterr.GRPCStatus()
		} else {
			st = status.New(codes.Internal, bizStatusErr.BizMessage())
		}
		s.SetBizStatusErr(bizStatusErr)
		tr.WriteStatus(s, st)
		return
	}
	tr.WriteStatus(s, status.New(codes.OK, ""))
}

// invokeStreamUnaryHandler allows unary APIs over HTTP2 to use the same server middleware as non-streaming APIs.
// For thrift unary APIs over HTTP2, it's enabled by default.
// For grpc(protobuf) unary APIs, it's disabled by default to keep backward compatibility.
func invokeStreamUnaryHandler(ctx context.Context, st streaming.Stream, mi serviceinfo.MethodInfo,
	handler endpoint.Endpoint, ri rpcinfo.RPCInfo,
) (err error) {
	realArgs, realResp := mi.NewArgs(), mi.NewResult()
	if err = st.RecvMsg(realArgs); err != nil {
		return err
	}
	if err = handler(ctx, realArgs, realResp); err != nil {
		return err
	}
	if ri != nil && ri.Invocation().BizStatusErr() != nil {
		// BizError: do not send the message
		return nil
	}
	return st.SendMsg(realResp)
}

// msg 是解码后的实例，如 Arg 或 Result, 触发上层处理，用于异步 和 服务端处理
func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
	panic("unimplemented")
}

type svrTransKey int

const (
	ctxKeySvrTransport svrTransKey = 1
	// align with default exitWaitTime
	defaultGraceTime time.Duration = 5 * time.Second
	// max poll time to check whether all the transports have finished
	// A smaller poll time will not affect performance and will speed up our unit tests
	defaultMaxPollTime = 50 * time.Millisecond
)

type SvrTrans struct {
	tr   grpcTransport.ServerTransport
	pool *sync.Pool // value is rpcInfo
	elem *list.Element
	// num of active handlers
	handlerNum int32
}

// 新连接建立时触发，主要用于服务端，对应 netpoll onPrepare
func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) {
	// set readTimeout to infinity to avoid streaming break
	// use keepalive to check the health of connection
	if npConn, ok := conn.(netpoll.Connection); ok {
		npConn.SetReadTimeout(grpcTransport.Infinity)
	} else {
		conn.SetReadDeadline(time.Now().Add(grpcTransport.Infinity))
	}

	tr, err := grpcTransport.NewServerTransport(ctx, conn, t.opt.GRPCCfg)
	if err != nil {
		return nil, err
	}
	pool := &sync.Pool{
		New: func() interface{} {
			// init rpcinfo
			ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
			return ri
		},
	}
	svrTrans := &SvrTrans{tr: tr, pool: pool}
	t.mu.Lock()
	elem := t.li.PushBack(svrTrans)
	t.mu.Unlock()
	svrTrans.elem = elem
	ctx = context.WithValue(ctx, ctxKeySvrTransport, svrTrans)
	return ctx, nil
}

// 连接关闭时回调
func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
	svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans)
	tr := svrTrans.tr
	t.mu.Lock()
	t.li.Remove(svrTrans.elem)
	t.mu.Unlock()
	tr.Close()
}

// 传输层 error 回调
func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
	var de *kerrors.DetailedError
	if ok := errors.As(err, &de); ok && de.Stack() != "" {
		klog.CtxErrorf(ctx, "KITEX: processing gRPC request error, remoteAddr=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), de.Stack())
	} else {
		klog.CtxErrorf(ctx, "KITEX: processing gRPC request error, remoteAddr=%s, error=%s", conn.RemoteAddr(), err.Error())
	}
}

func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
	t.inkHdlFunc = inkHdlFunc
}

func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) {
}

func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error {
	klog.Info("KITEX: gRPC GracefulShutdown starts")
	defer func() {
		klog.Info("KITEX: gRPC GracefulShutdown ends")
	}()
	t.mu.Lock()
	for elem := t.li.Front(); elem != nil; elem = elem.Next() {
		svrTrans := elem.Value.(*SvrTrans)
		svrTrans.tr.Drain()
	}
	t.mu.Unlock()
	graceTime, pollTime := parseGraceAndPollTime(ctx)
	graceTimer := time.NewTimer(graceTime)
	defer graceTimer.Stop()
	pollTick := time.NewTicker(pollTime)
	defer pollTick.Stop()
	for {
		select {
		case <-pollTick.C:
			var activeNums int32
			t.mu.Lock()
			for elem := t.li.Front(); elem != nil; elem = elem.Next() {
				svrTrans := elem.Value.(*SvrTrans)
				activeNums += atomic.LoadInt32(&svrTrans.handlerNum)
			}
			t.mu.Unlock()
			if activeNums == 0 {
				return nil
			}
		case <-graceTimer.C:
			klog.Info("KITEX: gRPC triggers Close")
			t.mu.Lock()
			for elem := t.li.Front(); elem != nil; elem = elem.Next() {
				svrTrans := elem.Value.(*SvrTrans)
				svrTrans.tr.Close()
			}
			t.mu.Unlock()
			return nil
		}
	}
}

func parseGraceAndPollTime(ctx context.Context) (graceTime, pollTime time.Duration) {
	graceTime = defaultGraceTime
	deadline, ok := ctx.Deadline()
	if ok {
		if customTime := time.Until(deadline); customTime > 0 {
			graceTime = customTime
		}
	}

	pollTime = graceTime / 10
	if pollTime > defaultMaxPollTime {
		pollTime = defaultMaxPollTime
	}
	return
}

func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context {
	c := t.opt.TracerCtl.DoStart(ctx, ri)
	return c
}

func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) {
	rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
	if rpcStats == nil {
		return
	}
	if panicErr != nil {
		rpcStats.SetPanicked(panicErr)
	}
	t.opt.TracerCtl.DoFinish(ctx, ri, err)
	rpcStats.Reset()
}
